# Alex Gazmararian
# agazmararian@gmail.com

# Load project configuration and the replication-mode helper.
# NOTE(review): the first path is relative to the working directory (unlike the
# here()-based paths below), so the script must be run from the project root.
# Names used later but not defined in this file (e.g. here, treat.vars,
# treat.labels, treat.list, covar.list, est_model, save_pnas_pdf,
# scale_color_party, pnas) are expected to come from these sourced scripts or
# the packages they attach -- TODO confirm.
source("analysis/visibility/analysis_config.R")
source(here("R", "visibility", "replication_mode.R"))

# Initialize replication mode (this script uses state-clustered SEs so works in both modes).
# REPLICATION_MODE is read again inside plot_het_effects() to decide whether the
# coordinate columns are kept in the analysis data.
REPLICATION_MODE <- init_replication_mode()
message("=== PROXIMITY HETEROGENEITY ANALYSIS ===")
message("Replication mode: ", REPLICATION_MODE)
message("(This analysis uses state-clustered SEs, no coordinate data required)")

# Analysis dataset; moderator columns are added to `g` further down the script.
g <- readRDS(here("data", "output", "visibility_analysis.rds"))

# Helper functions for heterogeneity table creation----

#' Build coefficient map for interaction models
#'
#' Produces the term-name -> display-label vector used to select, order and
#' rename coefficients in the heterogeneity tables. The result contains, in
#' order: the Q1-Q4 proximity main effects for every treatment variable, any
#' matching quintile terms from a global `coefmap` (if one exists), the main
#' effects of the non-reference moderator levels, and the
#' proximity x moderator interaction terms.
#'
#' The reference category is the first factor level, or the smallest value for
#' non-factor moderators. Non-factor values are always sorted so the result
#' does not depend on row order; a numeric 0/1 moderator therefore always uses
#' 0 as the reference and is named by the bare variable name (no level suffix).
#'
#' @param varnames Character vector of treatment variables (a trailing "_q"
#'   is stripped before the quintile suffix "_qQ1".."_qQ4" is appended)
#' @param hetvar Character, name of moderating variable
#' @param hetvar_labels Character vector of labels for moderating variable
#'   values (optional). Used only when its length equals the number of
#'   moderator values; labels are matched to values by position.
#' @param data Input data frame to get unique hetvar values
#' @return Named character vector for coefficient mapping
build_interaction_coefmap <- function(varnames, hetvar, hetvar_labels, data) {
  quintiles <- c("Q1", "Q2", "Q3", "Q4")
  base_vars <- gsub("_q$", "", varnames)

  # Main proximity effects for the variables being analyzed
  distance_terms <- character(0)
  for (base_var in base_vars) {
    for (quintile in quintiles) {
      distance_terms[paste0(base_var, "_q", quintile)] <- paste0(quintile, " proximity")
    }
  }

  # Also include quintile terms from the global coefmap (if present) that
  # refer to the same treatment variables and are not already covered
  if (exists("coefmap", envir = .GlobalEnv)) {
    base_coefmap <- get("coefmap", envir = .GlobalEnv)
    extra_terms <- base_coefmap[grepl("Q[1-4]$", names(base_coefmap))]
    extra_terms <- extra_terms[grepl(paste(base_vars, collapse = "|"), names(extra_terms))]
    for (term_name in names(extra_terms)) {
      if (!term_name %in% names(distance_terms)) {
        distance_terms[term_name] <- extra_terms[term_name]
      }
    }
  }

  # Moderator values in the order R uses for contrasts: factor levels as
  # given, everything else sorted (mirrors factor()'s default level order)
  hetvar_column <- data[[hetvar]]
  if (is.factor(hetvar_column)) {
    hetvar_vals <- levels(hetvar_column)
  } else {
    hetvar_vals <- unique(hetvar_column)
    hetvar_vals <- hetvar_vals[!is.na(hetvar_vals)]
    hetvar_vals <- hetvar_vals[hetvar_vals != ""]
    # Sort every type, not only character: otherwise the reference level of a
    # numeric/logical moderator would be whichever value appears first in data
    if (length(hetvar_vals) > 1) {
      hetvar_vals <- sort(hetvar_vals)
    }
  }

  # Everything after the first value is a non-reference level; drop NA/empty
  # levels (possible for factors) because they cannot form a valid term name
  non_ref_vals <- hetvar_vals[-1]
  non_ref_vals <- non_ref_vals[!is.na(non_ref_vals) & non_ref_vals != ""]

  use_labels <- !is.null(hetvar_labels) && length(hetvar_labels) == length(hetvar_vals)

  # Display label for one moderator value
  label_for <- function(value) {
    if (use_labels) {
      hetvar_labels[match(value, hetvar_vals)]
    } else {
      as.character(value)
    }
  }

  # Level suffix R appends to the variable name in the coefficient table.
  # For binary numeric variables (0/1), R uses just the variable name.
  suffix_for <- function(value) {
    if (is.numeric(value) && value == 1 && length(hetvar_vals) == 2) {
      ""
    } else {
      as.character(value)
    }
  }

  interaction_terms <- character(0)

  # Main effects of the moderator (non-reference levels)
  for (value in non_ref_vals) {
    interaction_terms[paste0(hetvar, suffix_for(value))] <- label_for(value)
  }

  # Quintile x moderator interactions for each treatment variable
  for (base_var in base_vars) {
    for (quintile in quintiles) {
      for (value in non_ref_vals) {
        term_name <- paste0(base_var, "_q", quintile, ":", hetvar, suffix_for(value))
        interaction_terms[term_name] <- paste0(quintile, " $\\times$ ", label_for(value))
      }
    }
  }

  c(distance_terms, interaction_terms)
}

#' Organize models for heterogeneity table
#'
#' Re-orders the flat model list (estimated outcome-by-outcome, treatments
#' nested within each outcome) so that columns are grouped by treatment, with
#' the outcomes side by side inside each treatment group. Each model is named
#' "<treatment label>_<short outcome label>".
#'
#' @param all_models List of all estimated models, where the model for outcome
#'   i and treatment j sits at position (i - 1) * length(varnames) + j
#' @param varnames Character vector of treatment variables
#' @param outcomes Character vector of outcome variables
#' @param outcome_labels Character vector of outcome labels
#' @param treatment_labels Character vector of treatment labels
#' @return Named list of models organized for table
organize_het_models <- function(all_models, varnames, outcomes, outcome_labels, treatment_labels) {

  # Compact column headers for the known outcome labels
  header_lookup <- c(
    "visibility probability" = "Visibility",
    "Biden credit attribution" = "Credit",
    "perceived benefits" = "Benefit"
  )

  # Short header for one outcome label; unknown labels fall back to their
  # first word in title case
  header_for <- function(label) {
    if (!(label %in% names(header_lookup))) {
      first_word <- strsplit(label, " ")[[1]][1]
      return(tools::toTitleCase(first_word))
    }
    header_lookup[[label]]
  }

  n_treatments <- length(varnames)
  organized <- list()

  # Treatments form the major groups; outcomes vary within each treatment
  for (treat_idx in seq_along(varnames)) {
    for (out_idx in seq_along(outcomes)) {
      fit <- all_models[[(out_idx - 1) * n_treatments + treat_idx]]

      # The list names double as modelsummary column identifiers
      key <- paste0(treatment_labels[treat_idx], "_", header_for(outcome_labels[out_idx]))
      organized[[key]] <- fit
    }
  }

  organized
}

#' Create comprehensive heterogeneity table
#'
#' Renders the interaction models as a LaTeX table (modelsummary + tinytable),
#' groups the columns under one spanning header per treatment, and writes the
#' result to `filename`.
#'
#' @param models Named list of models, ordered treatment-by-treatment with the
#'   outcomes nested inside each treatment; the text after the last "_" in each
#'   name becomes the column header
#' @param coef_map Coefficient mapping
#' @param hetvar Name of moderating variable
#' @param hetvar_labels Labels for moderating variable (when given, they are
#'   joined with " vs. " to describe the moderator in the title and notes)
#' @param treatment_labels Treatment labels
#' @param outcome_labels Outcome labels; only labels with a known description
#'   contribute a sentence to the table notes
#' @param filename Output filename
#' @param output_prefix Output prefix, used in the LaTeX label
#'   `tab:proximity_effect_het_<output_prefix>`
#' @param table_resize_width Numeric, manual resize width for tables (NULL uses automatic logic, 0 disables resizing)
#' @return The styled tinytable object, invisibly (the file is written as a
#'   side effect)
create_heterogeneity_table <- function(models, coef_map, hetvar, hetvar_labels, treatment_labels, 
                                     outcome_labels, filename, output_prefix, table_resize_width = NULL) {
  
  # Create mapping for cleaner variable names
  hetvar_name_map <- c(
    "college" = "education level",
    "pid3" = "partisanship", 
    "income_bin_lab" = "household income",
    "status_mfg_operating" = "manufacturing project status",
    "status_eia_d_re_2y" = "renewable energy project status",
    "popd_abovemed" = "population density",
    "sector_d_mfg_open_2y" = "manufacturing sector",
    "renewable_moderator" = "renewable energy technology",
    "re_1ybefore" = "renewable energy project timing",
    "mfg_1ybefore" = "manufacturing project timing",
    "govparty" = "governor party affiliation",
    "swing" = "swing state status",
    "gives_statement_biden_d_mfg_open_2y" = "Biden statement presence",
    "biden_dropped_out" = "Biden campaign exit timing"
  )
  
  # Describe the moderator: explicit labels win, then the lookup above, then a
  # cleaned-up version of the raw variable name
  if (!is.null(hetvar_labels)) {
    hetvar_title <- paste(hetvar_labels, collapse = " vs. ")
  } else if (hetvar %in% names(hetvar_name_map)) {
    hetvar_title <- hetvar_name_map[[hetvar]]
  } else {
    hetvar_title <- gsub("_", " ", hetvar)
    hetvar_title <- gsub("\\bd\\b", "distance to", hetvar_title)
    hetvar_title <- gsub("\\bmfg\\b", "manufacturing", hetvar_title)
    hetvar_title <- gsub("\\bre\\b", "renewable energy", hetvar_title)
    hetvar_title <- tools::toTitleCase(hetvar_title)
  }
  
  title <- paste0("Heterogeneous proximity effects by ", tolower(hetvar_title))
  
  # Build outcome descriptions dynamically based on outcome_labels used
  outcome_descriptions <- c(
    "visibility probability" = "Visibility = 1 if respondent reports a local green project, 0 otherwise.",
    "Biden credit attribution" = "Credit = 1 if respondent credits the Biden Administration for local green investments.",
    "perceived benefits" = "Benefit = 1 if respondent perceives a benefit from local green projects."
  )
  
  # Only describe outcomes we have text for; indexing with an unknown label
  # would otherwise inject a literal "NA" into the notes
  known_outcomes <- outcome_labels[outcome_labels %in% names(outcome_descriptions)]
  
  notes <- paste(
    c(
      "\\textit{Notes:}",
      paste0(
        "Each column reports a separate linear probability model with interactions between proximity and ",
        tolower(hetvar_title), "."
      ),
      "Unit of analysis is the individual survey respondent.",
      outcome_descriptions[known_outcomes],
      "Estimates are OLS with cluster-robust standard errors by state in parentheses.",
      "$^{*}p<0.05$, $^{**}p<0.01$, $^{***}p<0.001$."
    ),
    collapse = " "
  )
  
  # Create grouped column headers by treatment type
  n_models <- length(models)
  n_treatments <- length(treatment_labels)
  
  # Models are organized as: Treatment1_Outcome1, Treatment1_Outcome2, ..., Treatment2_Outcome1, ...
  # Table column 1 holds the term names, so model k sits in table column k + 1.
  n_outcomes <- n_models / n_treatments
  group_list <- list()
  for (i in seq_along(treatment_labels)) {
    start_col <- (i - 1) * n_outcomes + 2
    end_col <- i * n_outcomes + 1
    group_list[[treatment_labels[i]]] <- start_col:end_col
  }
  
  # Use the global goodness-of-fit map if available; otherwise show only
  # adjusted R-squared and N
  if (exists("gm", envir = .GlobalEnv)) {
    gm_use <- get("gm", envir = .GlobalEnv)
  } else {
    gm_use <- modelsummary::gof_map
    gm_use$omit <- TRUE
    gm_use[gm_use$raw == "adj.r.squared", ]$clean <- "Adjusted $R^2$"
    gm_use[gm_use$raw == "adj.r.squared", ]$omit <- FALSE
    gm_use[gm_use$raw == "nobs", ]$clean <- "$N$"
    gm_use[gm_use$raw == "nobs", ]$omit <- FALSE
  }
  
  # Column headers: the outcome part of each model name
  clean_labels <- gsub(".*_", "", names(models))
  
  # Create table
  table_out <- modelsummary::modelsummary(
    models,
    title = paste0(title, " \\label{tab:proximity_effect_het_", output_prefix, "}"),
    notes = notes,
    coef_map = coef_map,
    stars = c("*" = 0.05, "**" = 0.01, "***" = 0.001),
    add_rows = data.frame(
      term = c("Covariates", "State Fixed Effects"),
      matrix("Yes", nrow = 2, ncol = n_models, 
             dimnames = list(NULL, names(models)))
    ),
    gof_omit = "IC|RMSE|Std|FE|Within",
    gof_map = gm_use,
    output = "tinytable",
    escape = FALSE,
    fmt = modelsummary::fmt_significant(2)
  )
  table_out <- tinytable::group_tt(table_out, j = group_list)
  
  # Set clean column headers
  colnames(table_out) <- c("", clean_labels)
  
  # Decide on width resizing: a manual value wins (0 = never resize);
  # otherwise only tables with more than three models are scaled to 0.9
  if (is.null(table_resize_width)) {
    resize_width <- if (n_models <= 3) NULL else 0.9
  } else if (table_resize_width == 0) {
    resize_width <- NULL
  } else {
    resize_width <- table_resize_width
  }
  
  if (is.null(resize_width)) {
    table_out <- tinytable::theme_latex(table_out, placement = "H")
  } else {
    table_out <- tinytable::theme_latex(
      table_out,
      resize_width = resize_width, resize_direction = "both", placement = "H"
    )
  }
  
  # Write the file without overwriting `table_out`, so the table object itself
  # (not save_tt()'s return value) is what gets returned
  tinytable::save_tt(table_out, filename, overwrite = TRUE)
  
  message("Heterogeneity table saved to: ", filename)
  invisible(table_out)
}

# Wrapper function for treatment effect heterogeneity analysis----
#' Plot treatment effect heterogeneity across outcomes and moderating variables
#'
#' For every outcome x treatment pair this calls `est_model()` with the
#' moderator interaction (state fixed effects, state-clustered SEs), computes
#' average slopes by moderator level with `marginaleffects::avg_slopes()`,
#' and draws one point-range panel per pair on a shared y-axis. Optionally the
#' panels are combined into one patchwork figure, the figure is saved, and a
#' LaTeX table of all models is written.
#'
#' Relies on objects defined outside this function: `treat.vars`,
#' `treat.labels`, `covar.list`, `REPLICATION_MODE`, `est_model()`,
#' `save_pnas_pdf()`, `scale_color_party()` and, when no figure/table number
#' is supplied, the flag `pnas`.
#'
#' @param data Input data frame
#' @param varnames Character vector of treatment variables (default: the global `treat.vars`)
#' @param hetvar Character, name of moderating variable
#' @param outcomes Character vector of outcomes (default: c("greenproj_bin", "credit_biden_bin", "greenbenefit_bin"))
#' @param outcome_labels Character vector of outcome labels for plots (default: c("visibility probability", "Biden credit attribution", "perceived benefits"))
#' @param treatment_labels Character vector of treatment labels for plot titles (default: the global `treat.labels` with underscores replaced by spaces, in title case)
#' @param hetvar_labels Character vector of labels for moderating variable values (optional)
#' @param color_scale Character or function, color scale to use ("default", "party", or custom function). "party" only takes effect when `hetvar` is "pid3".
#' @param legend_label Character, custom label for the legend (default: uses hetvar name)
#' @param combine_plots Logical, whether to combine plots with patchwork (default: TRUE)
#' @param save_plots Logical, whether to save plots automatically (default: TRUE). Also gates writing of the table.
#' @param output_prefix Character, prefix for output filenames (required if save_plots=TRUE)
#' @param fig_number Character, figure number for supplementary appendix (e.g., "S5" produces "fig_S5_..."). When NULL the figure is only saved if the global `pnas` is FALSE.
#' @param tab_number Character, table number for supplementary appendix (e.g., "S3" produces "tab_S3_..."). When NULL the table is only saved if the global `pnas` is FALSE.
#' @param width Character, plot width passed to save_pnas_pdf ("single", "double", "intermediate"). Only used by the individual-plot fallback; the combined figure is always saved with width = "double".
#' @param height Numeric, plot height passed to save_pnas_pdf. Only used by the individual-plot fallback; the combined figure is always saved with height = 12.
#' @param table_resize_width Numeric, manual resize width for tables (NULL uses automatic logic, 0 disables resizing)
#' @return List keyed by outcome; each element is a list of ggplot objects keyed "<varname>_<outcome>". If a combined plot was built it is stored as "mega_combined" in the first outcome's list.
plot_het_effects <- function(data, 
                            varnames = treat.vars,
                            hetvar,
                            outcomes = c("greenproj_bin", "credit_biden_bin", "greenbenefit_bin"),
                            outcome_labels = c("visibility probability", "Biden credit attribution", "perceived benefits"),
                            treatment_labels = tools::toTitleCase(gsub("_", " ", treat.labels)),
                            hetvar_labels = NULL,
                            color_scale = "default",
                            legend_label = NULL,
                            combine_plots = TRUE,
                            save_plots = TRUE,
                            output_prefix = NULL,
                            fig_number = NULL,
                            tab_number = NULL,
                            width = "double",
                            height = 14,
                            table_resize_width = NULL) {
  
  # Input validation
  if (length(varnames) != length(treatment_labels)) {
    stop("varnames and treatment_labels must have same length")
  }
  if (length(outcomes) != length(outcome_labels)) {
    stop("outcomes and outcome_labels must have same length")
  }
  if (save_plots && is.null(output_prefix)) {
    stop("output_prefix required when save_plots=TRUE")
  }
  
  # Subset data to only required columns

  # Base columns needed for analysis (uses state-clustered SEs)
  select_cols <- c(varnames, covar.list, hetvar, outcomes, "state")
  
  # In full mode, also include coordinate columns for consistency with main dataset
  # (not used by this analysis since se = "state", but keeps data structure intact)
  if (REPLICATION_MODE == "full") {
    select_cols <- c(select_cols, "lat_zip", "lon_zip")
  }
  
  data <- data %>%
    select(all_of(select_cols))

  # Initialize results list
  results <- list()
  
  # First pass: collect all effect estimates and models to determine y-axis limits.
  # Both lists are filled outcome-by-outcome with treatments nested inside, so
  # the entry for outcome i / treatment j is at (i - 1) * length(varnames) + j.
  all_effects <- list()
  all_models <- list()
  effect_counter <- 1
  
  for (i in seq_along(outcomes)) {
    outcome <- outcomes[i]
    for (j in seq_along(varnames)) {
      varname <- varnames[j]
      message("Processing: ", outcome, " ~ ", varname, " by ", hetvar)
      
      # Force garbage collection to free memory
      gc()
      
      model <- est_model(varname = varname, hetvar = hetvar, outcome = outcome, data.in = data, fe = "| state", se = "state")
      effects_df <- model %>%
        marginaleffects::avg_slopes(variables = varname, by = hetvar) %>%
        data.frame()
      all_effects[[effect_counter]] <- effects_df
      all_models[[effect_counter]] <- model
      effect_counter <- effect_counter + 1
    }
  }
  
  # Calculate overall y-axis limits with 10% padding.
  # The padding must always widen the range: multiplying a positive lower bound
  # (or a negative upper bound) by 1.1 would move it inward, and limits set via
  # scale_y_continuous() drop anything outside them.
  all_estimates <- do.call(rbind, all_effects)
  ci_low <- min(all_estimates$conf.low, na.rm = TRUE)
  ci_high <- max(all_estimates$conf.high, na.rm = TRUE)
  y_min <- ci_low * (if (ci_low < 0) 1.1 else 0.9)
  y_max <- ci_high * (if (ci_high > 0) 1.1 else 0.9)
  
  # Second pass: create plots with consistent y-axis limits
  for (i in seq_along(outcomes)) {
    outcome <- outcomes[i]
    outcome_label <- outcome_labels[i]
    
    # Initialize list for this outcome
    outcome_plots <- list()
    
    for (j in seq_along(varnames)) {
      varname <- varnames[j]
      treatment_label <- treatment_labels[j]
      
      # Use pre-calculated effects from first pass
      effects_df <- all_effects[[(i-1)*length(varnames) + j]]
      
      # Apply hetvar labels if provided
      if (!is.null(hetvar_labels)) {
        if (is.factor(effects_df[[hetvar]])) {
          levels(effects_df[[hetvar]]) <- hetvar_labels
        } else {
          # Create mapping for non-factor variables.
          # NOTE(review): labels are matched to values in order of first
          # appearance in effects_df -- confirm this is the intended order.
          unique_vals <- unique(effects_df[[hetvar]])
          if (length(unique_vals) == length(hetvar_labels)) {
            effects_df[[hetvar]] <- factor(effects_df[[hetvar]], 
                                         levels = unique_vals, 
                                         labels = hetvar_labels)
          }
        }
      }
      
      # Create base plot with color and shape variations
      p <- effects_df %>%
        ggplot(aes(x = contrast, y = estimate, ymin = conf.low, ymax = conf.high, 
                  color = factor(.data[[hetvar]]), shape = factor(.data[[hetvar]]))) +
        geom_hline(yintercept = 0, color = "grey") +
        geom_pointrange(size = 1, linewidth = 1, position = position_dodge(width = 0.5)) +
        theme_classic(base_size = 10) +
        labs(
          x = "Distance quintiles (reference: Q5 = farthest group)",
          y = paste("Effect on", outcome_label),
          title = treatment_label,
          color = if(!is.null(legend_label)) legend_label else tools::toTitleCase(hetvar),
          shape = if(!is.null(legend_label)) legend_label else tools::toTitleCase(hetvar)
        ) +
        theme(legend.position = "bottom") +
        # Apply consistent y-axis limits
        scale_y_continuous(limits = c(y_min, y_max))
      
      # Apply color scale. Any other value (e.g. "party" with a moderator other
      # than pid3) leaves ggplot2's default discrete palette in place.
      if (is.character(color_scale) && color_scale == "party" && hetvar == "pid3") {
        p <- p + scale_color_party()
      } else if (is.function(color_scale)) {
        p <- p + color_scale()
      } else if (identical(color_scale, "default")) {
        # Use greyscale as default
        p <- p + scale_color_grey(start = 0.2, end = 0.7)
      }
      # Add manual shape scale for better distinction
      p <- p + scale_shape_manual(values = c(16, 17, 15, 18, 8, 4, 3, 7, 10, 12))
      
      # Store plot
      outcome_plots[[paste0(varname, "_", outcome)]] <- p
    }
    
    # Store plots for this outcome
    results[[outcome]] <- outcome_plots
  }
  
  # Create mega-plot combining all outcomes and treatments
  mega_plot <- NULL
  if (combine_plots && length(results) > 0) {
    all_plots <- list()
    plot_counter <- 1
    
    # Organize plots by outcome (rows) and treatment (columns)
    for (i in seq_along(outcomes)) {
      outcome <- outcomes[i]
      
      for (j in seq_along(varnames)) {
        varname <- varnames[j]
        treatment_label <- treatment_labels[j]
        
        plot_key <- paste0(varname, "_", outcome)
        if (plot_key %in% names(results[[outcome]])) {
          # Panel title names whichever dimension(s) actually vary
          panel_title <- if (length(varnames) > 1 && length(outcomes) > 1) {
            # Multi-treatment, multi-outcome: show both
            paste(treatment_label, "-", outcome_labels[i])
          } else if (length(varnames) > 1) {
            # Multi-treatment, single outcome: show treatment only
            treatment_label
          } else if (length(outcomes) > 1) {
            # Single treatment, multi-outcome: show outcome only
            outcome_labels[i]
          } else {
            # Single treatment, single outcome: no title needed
            ""
          }
          
          # Note: this also retitles the individual plots returned in `results`
          if (panel_title != "") {
            results[[outcome]][[plot_key]] <- results[[outcome]][[plot_key]] +
              labs(title = panel_title) +
              theme(plot.title = element_text(hjust = 0.5, size = 11, face = "bold"))
          } else {
            results[[outcome]][[plot_key]] <- results[[outcome]][[plot_key]] +
              labs(title = "") +
              theme(plot.title = element_blank())
          }
          
          all_plots[[plot_counter]] <- results[[outcome]][[plot_key]]
          plot_counter <- plot_counter + 1
        }
      }
    }
    
    # Create mega-plot with grid layout
    if (length(all_plots) > 0) {
      mega_plot <- patchwork::wrap_plots(all_plots, 
                                        ncol = length(varnames),  # columns = treatments
                                        nrow = length(outcomes))  # rows = outcomes
      
      # No overall title/subtitle since we use individual panel titles
      
      # Shared legend at the bottom; party plots additionally hide the legend title
      if (is.character(color_scale) && color_scale == "party" && hetvar == "pid3") {
        mega_plot <- mega_plot + 
          patchwork::plot_layout(guides = "collect") &
          theme(legend.position = "bottom", legend.title = element_blank())
      } else {
        mega_plot <- mega_plot + 
          patchwork::plot_layout(guides = "collect") &
          theme(legend.position = "bottom")
      }
      
      # Store mega-plot once, under the first outcome
      results[[names(results)[1]]][["mega_combined"]] <- mega_plot
    }
  }
    
  # Save mega-plot if requested
  # Skip unnumbered (diagnostic) figures when pnas = TRUE
  # Numbered figures (with fig_number) are always saved
  should_save <- save_plots && combine_plots && (!is.null(fig_number) || isFALSE(pnas))
  
  if (should_save) {
    fig_prefix <- if (!is.null(fig_number)) paste0("fig_", fig_number, "_") else "fig_"
    
    if (!is.null(mega_plot)) {
      filename <- here("output", "pnas", "figures", 
                      paste0(fig_prefix, "proximity_effect_", output_prefix, ".pdf"))
      # Fixed dimensions for the combined figure (the width/height arguments
      # are not used here)
      save_pnas_pdf(mega_plot, filename, 
                   width = "double", height = 12)
    } else {
      # Fallback: save individual plots if no mega-plot.
      # NOTE(review): should_save requires combine_plots = TRUE, so this is only
      # reached when no combined plot could be built -- confirm whether
      # combine_plots = FALSE was meant to save individual plots too.
      for (outcome in names(results)) {
        outcome_suffix <- case_when(
          outcome == "greenproj_bin" ~ "",
          outcome == "credit_biden_bin" ~ "_credit", 
          outcome == "greenbenefit_bin" ~ "_benefit",
          TRUE ~ paste0("_", gsub("_bin$", "", outcome))
        )
        
        for (plot_name in names(results[[outcome]])) {
          if (!grepl("combined", plot_name)) {
            filename <- here("output", "pnas", "figures", 
                            paste0(fig_prefix, "proximity_effect_", output_prefix, outcome_suffix, "_", 
                                  gsub(paste0("_", outcome), "", plot_name), ".pdf"))
            save_pnas_pdf(results[[outcome]][[plot_name]], filename, width = width, height = height)
          }
        }
      }
    }
  }
  
  # Create comprehensive table with all models----
  # Skip unnumbered (diagnostic) tables when pnas = TRUE
  # Numbered tables (with tab_number) are always saved
  should_save_table <- save_plots && !is.null(output_prefix) && (!is.null(tab_number) || isFALSE(pnas))
  
  if (should_save_table) {
    message("Creating heterogeneity analysis table...")
    
    # Build dynamic coefmap for interaction terms
    hetvar_coefmap <- build_interaction_coefmap(varnames, hetvar, hetvar_labels, data)
    
    # Organize models for table
    table_models <- organize_het_models(all_models, varnames, outcomes, outcome_labels, treatment_labels)
    
    # Create the table
    tab_prefix <- if (!is.null(tab_number)) paste0("tab_", tab_number, "_") else "tab_"
    create_heterogeneity_table(
      models = table_models,
      coef_map = hetvar_coefmap,
      hetvar = hetvar,
      hetvar_labels = hetvar_labels,
      treatment_labels = treatment_labels,
      outcome_labels = outcome_labels,
      filename = here("output", "pnas", "tables", paste0(tab_prefix, "proximity_effect_het_", output_prefix, ".tex")),
      output_prefix = output_prefix,
      table_resize_width = table_resize_width
    )
  }
  
  return(results)
}

# Each call below estimates the interaction models for one moderator and, via
# plot_het_effects(), writes the numbered supplementary figure and table.

# Education-----
# Labels are given for both values of `college`.
college_results <- plot_het_effects(
  data = g,
  hetvar = "college", 
  legend_label = "Education",
  hetvar_labels = c("No college", "College"),
  output_prefix = "college",
  fig_number = "S10",
  tab_number = "S24"
)

# Income-----
# No hetvar_labels: the raw values of income_bin_lab are used as labels.
income_results <- plot_het_effects(
  data = g,
  hetvar = "income_bin_lab",
  legend_label = "Household Income",
  output_prefix = "income",
  fig_number = "S11",
  tab_number = "S25"
)

# Partisanship-----
# color_scale = "party" applies scale_color_party() because hetvar is "pid3".
pid_results <- plot_het_effects(
  data = g,
  hetvar = "pid3",
  color_scale = "party",
  legend_label = "Partisanship",
  output_prefix = "pid",
  fig_number = "S9",
  tab_number = "S23"
)

# Project status----
# Create indicator for operating manufacturing projects.
# Diagnostic: distribution of the raw status variable.
table(g$status_bgm_d_mfg_any_2y)
# `%in%` never returns NA, so respondents with a missing status are coded
# "Other" along with every status not listed here.
g$status_mfg_operating <- ifelse(
  g$status_bgm_d_mfg_any_2y %in% c("Operating", "Operating Partially; Under Construction"), 
  "Operating", "Other")
prop.table(table(g$status_mfg_operating))

# Note: manufacturing uses status_mfg_operating (derived from
# status_bgm_d_mfg_any_2y) with d_mfg_any_2y_q; renewables use the
# status_eia_* variable matching the renewable treatment variable.
# Need separate analyses since heterogeneity variables differ

# Manufacturing status analysis
status_mfg_results <- plot_het_effects(
  data = g,
  varnames = "d_mfg_any_2y_q",
  treatment_labels = tools::toTitleCase(gsub("_", " ", "manufacturing")), 
  hetvar = "status_mfg_operating",
  legend_label = "Project Status",
  output_prefix = "status_mfg",
  fig_number = "S5",
  tab_number = "S19"
)

# Renewable status analysis
# The moderator name is built from the treatment variable with "_q" removed
# and "status_eia_" prepended.
status_re_results <- plot_het_effects(
  data = g,
  varnames = treat.list[["renewable_energy"]],
  treatment_labels = tools::toTitleCase(gsub("_", " ", "renewable_energy")),
  hetvar = paste0("status_eia_", gsub("_q", "", treat.list[["renewable_energy"]])), 
  legend_label = "Project Status",
  output_prefix = "status_re",
  fig_number = "S6",
  tab_number = "S20"
)

# Sector----

# Manufacturing sector analysis
# Uses a viridis palette and a manual table resize width of 0.4.
sector_mfg_results <- plot_het_effects(
  data = g,
  varnames = treat.list[["manufacturing"]],
  treatment_labels = "Manufacturing",
  hetvar = paste0("sector_", gsub("_q", "", treat.list[["manufacturing"]])),
  legend_label = "Sector",
  color_scale = scale_color_viridis_d,
  output_prefix = "mfg_sector",
  table_resize_width = 0.4,
  fig_number = "S8",
  tab_number = "S22"
)

# Renewable technology analysis
# Diagnostic: distribution of the raw technology variable.
# NOTE(review): this column name is hard-coded, whereas the moderator below is
# built from treat.list[["renewable_energy"]] -- confirm they refer to the same
# variable.
table(g$technology_d_re_2y)
# Two-level moderator: wind turbines (offshore or onshore) vs. everything else.
# Every non-wind value, including missing values, is coded "Solar".
g$renewable_moderator <- ifelse(
  g[[paste0("technology_", gsub("_q", "", treat.list[["renewable_energy"]]))]] %in% c("Offshore Wind Turbine", "Onshore Wind Turbine"), "Wind", "Solar")
table(g$renewable_moderator)

sector_re_results <- plot_het_effects(
  data = g,
  varnames = treat.list[["renewable_energy"]],
  treatment_labels = tools::toTitleCase(gsub("_", " ", "renewable_energy")),
  hetvar = "renewable_moderator",
  legend_label = "Technology",
  output_prefix = "re_sector",
  fig_number = "S7",
  tab_number = "S21"
)

# Biden dropped out of the race----
# Biden ended his re-election bid on July 21, 2024.
# Responses with start_date strictly before that date are "Before"; responses
# on or after it are "After". "Before" is the first (reference) factor level.
g$biden_dropped_out <- factor(
  ifelse(g$start_date < as.Date("2024-07-21"), "Before", "After"),
  levels = c("Before", "After")
)

# Check distribution across survey timing
message("Survey responses by Biden dropout timing:")
print(table(g$biden_dropped_out))
print(prop.table(table(g$biden_dropped_out)))

# Cross-tab with the `sample` variable
message("\nBreakdown by survey sample:")
print(table(g$sample, g$biden_dropped_out))

# Only the visibility and credit outcomes are analyzed for this moderator.
biden_dropped_out_results <- plot_het_effects(
  data = g,
  hetvar = "biden_dropped_out",
  outcomes = c("greenproj_bin", "credit_biden_bin"),
  outcome_labels = c("visibility probability", "Biden credit attribution"),
  legend_label = "Survey Timing",
  output_prefix = "biden_dropped_out",
  hetvar_labels = c("Before Biden exit", "After Biden exit"),
  fig_number = "S12",
  tab_number = "S26"
)

